#include "nvrtc_wrapper.h"
#include "cuda_driver_types.h"
#include <stdexcept>
#include <string>
#include <vector>

// Version number provided by another translation unit. The macros below decode
// it as major = value / 1000 and minor = (value - major * 1000) / 10.
extern int __wcuda_version_internal__;

#define WNV_RTC_VERSION_MAJOR (__wcuda_version_internal__ / 1000)
// The whole expression is parenthesized so the macro expands to a single
// operand; previously the trailing "/ 10" sat outside the parentheses and
// could bind to neighbouring operators at the expansion site.
#define WNV_RTC_VERSION_MINOR                                                  \
  ((__wcuda_version_internal__ - WNV_RTC_VERSION_MAJOR * 1000) / 10)

// Architecture values reported to callers: ten times the target SM number.
#define WNV_RTC_SUPPORT_ARCH_V100 (10 * WNV_TARGET_SM_70) // 700
#define WNV_RTC_SUPPORT_ARCH_T4 (10 * WNV_TARGET_SM_75)   // 750
#define WNV_RTC_SUPPORT_ARCH_A100 (10 * WNV_TARGET_SM_80) // 800

// Prefixes searched for in compile options to locate the target architecture.
#define SM_ "sm_"
#define COMPUTE_ "compute_"

// Derive the architecture value (target SM number times ten) from a list of
// compile options.
//
// Options are scanned in order. The first option containing "compute_" or
// "sm_" ("compute_" is tried first within an option) ends the scan: the number
// that follows the prefix is parsed, and if it equals WNV_TARGET_SM_70,
// WNV_TARGET_SM_75 or WNV_TARGET_SM_80 the result is that number times ten.
// Any other number leaves the default, WNV_RTC_SUPPORT_ARCH_A100, in place.
//
// oarr:    option strings; null entries are skipped.
// returns: the architecture value, or -1 when the text after the prefix cannot
//          be parsed as an unsigned number (std::stoul throws).
int getArchVersion(std::vector<const char *> &oarr) {
  int archVersion = WNV_RTC_SUPPORT_ARCH_A100;

  for (const char *rawOption : oarr) {
    if (rawOption == nullptr) {
      // Constructing std::string from a null pointer is undefined behaviour.
      continue;
    }
    const std::string option(rawOption);

    // sizeof on a string literal counts the terminating NUL, hence the "- 1"
    // for both prefixes. The "compute_" branch used to omit it, which dropped
    // the first digit of the version (e.g. "compute_70" was parsed as 0).
    size_t prefixLength = sizeof(COMPUTE_) - 1;
    size_t pos = option.find(COMPUTE_);
    if (pos == std::string::npos) {
      prefixLength = sizeof(SM_) - 1;
      pos = option.find(SM_);
    }
    if (pos == std::string::npos) {
      continue; // this option names no architecture
    }

    try {
      const unsigned long version = std::stoul(option.substr(pos + prefixLength));

      if (version == WNV_TARGET_SM_70 || version == WNV_TARGET_SM_75 ||
          version == WNV_TARGET_SM_80) {
        archVersion = static_cast<int>(version * 10);
      }
    } catch (const std::invalid_argument &) {
      return -1; // no digits after the prefix
    } catch (const std::out_of_range &) {
      return -1; // number does not fit in unsigned long
    }

    // Only the first architecture-bearing option is considered.
    break;
  }

  return archVersion;
}

// Returns the message for `result` exactly as produced by
// mcrtcGetErrorString.
const char *wnvrtcGetErrorString(mcrtcResult result) {
  const char *message = mcrtcGetErrorString(result);
  return message;
}

// Reports how many architecture entries wnvrtcGetSupportedArchs writes.
//
// wnvrtcGetSupportedArchs always fills exactly three slots (V100, T4, A100),
// so the count must be three as well: callers size their buffer from this
// value, and a different count from the backend would lead to an out-of-bounds
// write or to uninitialised trailing entries.
//
// numArchs: receives the count; must not be null.
// returns:  MCRTC_ERROR_INVALID_INPUT when numArchs is null, otherwise
//           MCRTC_SUCCESS.
mcrtcResult wnvrtcGetNumSupportedArchs(int *numArchs) {
  if (numArchs == nullptr) {
    return MCRTC_ERROR_INVALID_INPUT;
  }

  // Must stay equal to the number of entries wnvrtcGetSupportedArchs writes.
  const int kReportedArchCount = 3;
  *numArchs = kReportedArchCount;

  return MCRTC_SUCCESS;
}

// Writes the three reported architecture values into supportedArchs[0..2], in
// the order V100, T4, A100.
//
// supportedArchs: destination array; a null pointer is rejected with
//                 MCRTC_ERROR_INVALID_INPUT.
// returns:        MCRTC_SUCCESS once the three entries have been written.
mcrtcResult wnvrtcGetSupportedArchs(int *supportedArchs) {
  if (!supportedArchs) {
    return MCRTC_ERROR_INVALID_INPUT;
  }

  const int reportedArchs[] = {WNV_RTC_SUPPORT_ARCH_V100,
                               WNV_RTC_SUPPORT_ARCH_T4,
                               WNV_RTC_SUPPORT_ARCH_A100};

  int slot = 0;
  for (int archValue : reportedArchs) {
    supportedArchs[slot] = archValue;
    ++slot;
  }

  return MCRTC_SUCCESS;
}

// Stores the major and minor version numbers decoded by the
// WNV_RTC_VERSION_MAJOR / WNV_RTC_VERSION_MINOR macros.
//
// major, minor: output locations; if either is null nothing is written and
//               MCRTC_ERROR_INVALID_INPUT is returned.
// returns:      MCRTC_SUCCESS after both values have been stored.
mcrtcResult wnvrtcVersion(int *major, int *minor) {
  if (!major) {
    return MCRTC_ERROR_INVALID_INPUT;
  }
  if (!minor) {
    return MCRTC_ERROR_INVALID_INPUT;
  }

  // Major is stored before minor is evaluated, as in the original ordering.
  *major = WNV_RTC_VERSION_MAJOR;
  *minor = WNV_RTC_VERSION_MINOR;
  return MCRTC_SUCCESS;
}

// Delegates to mcrtcAddNameExpression with the same program handle and name
// expression, and returns its result unchanged.
mcrtcResult wnvrtcAddNameExpression(mcrtcProgram prog,
                                    const char *const name_expression) {
  const mcrtcResult status = mcrtcAddNameExpression(prog, name_expression);
  return status;
}

// Compiles `prog` through mcrtcCompileProgram after appending a fixed set of
// preprocessor definitions to the caller's options:
//   -D__CUDACC_RTC__, -D__CUDACC__,
//   -D__CUDACC_VER_MAJOR__=<major>, -D__CUDACC_VER_MINOR__=<minor>,
//   -D__CUDA_ARCH__=<value from getArchVersion>.
//
// prog:       program handle, passed through unchanged.
// numOptions: number of entries in `options`; values <= 0 mean no options.
// options:    caller-supplied option strings.
// returns:    MCRTC_ERROR_INVALID_INPUT when numOptions > 0 but `options` or
//             one of its first numOptions entries is null;
//             MCRTC_ERROR_INVALID_OPTION when getArchVersion reports a
//             malformed architecture; otherwise the result of
//             mcrtcCompileProgram.
mcrtcResult wnvrtcCompileProgram(mcrtcProgram prog, int numOptions,
                                 const char *const *options) {
  if (numOptions > 0 && options == nullptr) {
    return MCRTC_ERROR_INVALID_INPUT;
  }

  std::vector<const char *> oarr;
  if (numOptions > 0) {
    // Room for the caller's options plus the five definitions added below.
    oarr.reserve(static_cast<size_t>(numOptions) + 5);
  }
  for (int i = 0; i < numOptions; ++i) {
    if (options[i] == nullptr) {
      // A null entry would be dereferenced while scanning for the architecture.
      return MCRTC_ERROR_INVALID_INPUT;
    }
    oarr.push_back(options[i]);
  }

  // These strings own the storage that oarr points into, so they must stay
  // alive until mcrtcCompileProgram returns.
  const int versionMajor = WNV_RTC_VERSION_MAJOR;
  const int versionMinor = WNV_RTC_VERSION_MINOR;
  const std::string verMajor =
      "-D__CUDACC_VER_MAJOR__=" + std::to_string(versionMajor);
  const std::string verMinor =
      "-D__CUDACC_VER_MINOR__=" + std::to_string(versionMinor);

  // Additional definitions for CUDA source.
  oarr.push_back("-D__CUDACC_RTC__");
  oarr.push_back("-D__CUDACC__");
  oarr.push_back(verMajor.c_str());
  oarr.push_back(verMinor.c_str());

  // Architecture definition derived from the options collected so far.
  const int archVersion = getArchVersion(oarr);
  if (archVersion < 0) {
    return MCRTC_ERROR_INVALID_OPTION;
  }
  const std::string arch = "-D__CUDA_ARCH__=" + std::to_string(archVersion);
  oarr.push_back(arch.c_str());

  // The option count is small; make the size_t -> int conversion explicit.
  return mcrtcCompileProgram(prog, static_cast<int>(oarr.size()), oarr.data());
}

// Delegates to mcrtcCreateProgram, passing every argument through unchanged,
// and returns its result.
mcrtcResult wnvrtcCreateProgram(mcrtcProgram *prog, const char *src,
                                const char *name, int numHeaders,
                                const char *const *headers,
                                const char *const *includeNames) {
  const mcrtcResult status =
      mcrtcCreateProgram(prog, src, name, numHeaders, headers, includeNames);
  return status;
}

// Delegates to mcrtcDestroyProgram with the same program pointer and returns
// its result.
mcrtcResult wnvrtcDestroyProgram(mcrtcProgram *prog) {
  const mcrtcResult status = mcrtcDestroyProgram(prog);
  return status;
}

// Delegates to mcrtcGetLoweredName with the same program handle, name
// expression and output pointer, and returns its result.
mcrtcResult wnvrtcGetLoweredName(mcrtcProgram prog,
                                 const char *const name_expression,
                                 const char **lowered_name) {
  const mcrtcResult status =
      mcrtcGetLoweredName(prog, name_expression, lowered_name);
  return status;
}

// PTX retrieval is mapped onto mcrtcGetBitcode: `prog` and `ptx` are passed
// through unchanged and the backend's result is returned.
mcrtcResult wnvrtcGetPTX(mcrtcProgram prog, char *ptx) {
  const mcrtcResult status = mcrtcGetBitcode(prog, ptx);
  return status;
}

// PTX size queries are mapped onto mcrtcGetBitcodeSize: `prog` and
// `ptxSizeRet` are passed through unchanged and the backend's result is
// returned.
mcrtcResult wnvrtcGetPTXSize(mcrtcProgram prog, size_t *ptxSizeRet) {
  const mcrtcResult status = mcrtcGetBitcodeSize(prog, ptxSizeRet);
  return status;
}

// NVVM retrieval is mapped onto mcrtcGetBitcode: `prog` and `nvvm` are passed
// through unchanged and the backend's result is returned.
mcrtcResult wnvrtcGetNVVM(mcrtcProgram prog, char *nvvm) {
  const mcrtcResult status = mcrtcGetBitcode(prog, nvvm);
  return status;
}

// NVVM size queries are mapped onto mcrtcGetBitcodeSize: `prog` and
// `nvvmSizeRet` are passed through unchanged and the backend's result is
// returned.
mcrtcResult wnvrtcGetNVVMSize(mcrtcProgram prog, size_t *nvvmSizeRet) {
  const mcrtcResult status = mcrtcGetBitcodeSize(prog, nvvmSizeRet);
  return status;
}

// CUBIN retrieval is mapped onto mcrtcGetBitcode: `prog` and `cubin` are
// passed through unchanged and the backend's result is returned.
mcrtcResult wnvrtcGetCUBIN(mcrtcProgram prog, char *cubin) {
  const mcrtcResult status = mcrtcGetBitcode(prog, cubin);
  return status;
}

// CUBIN size queries are mapped onto mcrtcGetBitcodeSize: `prog` and
// `cubinSizeRet` are passed through unchanged and the backend's result is
// returned.
mcrtcResult wnvrtcGetCUBINSize(mcrtcProgram prog, size_t *cubinSizeRet) {
  const mcrtcResult status = mcrtcGetBitcodeSize(prog, cubinSizeRet);
  return status;
}

// Delegates to mcrtcGetProgramLog with the same program handle and log
// pointer, and returns its result.
mcrtcResult wnvrtcGetProgramLog(mcrtcProgram prog, char *log) {
  const mcrtcResult status = mcrtcGetProgramLog(prog, log);
  return status;
}

// Delegates to mcrtcGetProgramLogSize with the same program handle and size
// pointer, and returns its result.
mcrtcResult wnvrtcGetProgramLogSize(mcrtcProgram prog, size_t *logSizeRet) {
  const mcrtcResult status = mcrtcGetProgramLogSize(prog, logSizeRet);
  return status;
}